{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Editable install of the local cellnet checkout.\n",
    "# %pip (instead of !pip) installs into the environment of the running kernel.\n",
    "# --no-deps: pip does not install cellnet's dependencies.\n",
    "%pip install -e /dss/dsshome1/04/di93zer/git/cellnet --no-deps"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import seaborn as sns\n",
    "import torch\n",
    "\n",
    "from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, TQDMProgressBar\n",
    "from lightning.pytorch.loggers import TensorBoardLogger\n",
    "from lightning.pytorch.utilities.model_summary import ModelSummary\n",
    "from lightning.pytorch import seed_everything"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Select the 'high' float32 matmul precision setting in PyTorch.\n",
    "torch.set_float32_matmul_precision('high')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Load IPython's autoreload extension, which provides the %autoreload magic.\n",
    "%load_ext autoreload"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Bare %autoreload triggers a one-off reload of already-imported modules right now;\n",
    "# it does not switch on automatic reloading (that would be `%autoreload 2`).\n",
    "# Re-run this cell after editing cellnet to pick up the changes.\n",
    "%autoreload\n",
    "from cellnet.estimators import EstimatorCellTypeClassifier"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Init model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Run configuration: model name plus checkpoint, log and data locations.\n",
    "MODEL = 'cxg_2023_05_15_lung_only_tabnet'\n",
    "TB_LOGS_ROOT = '/mnt/dssfs02/tb_logs'\n",
    "CHECKPOINT_PATH = os.path.join(TB_LOGS_ROOT, MODEL)\n",
    "LOGS_PATH = os.path.join(TB_LOGS_ROOT, MODEL)\n",
    "DATA_PATH = '/mnt/dssmcmlfs01/merlin_cxg_2023_05_15_sf-log1p_lung_only'\n",
    "\n",
    "# Build the estimator, fix the global seed, then set up the datamodule.\n",
    "estim = EstimatorCellTypeClassifier(DATA_PATH)\n",
    "seed_everything(1)\n",
    "estim.init_datamodule(batch_size=2048)\n",
    "\n",
    "# Logger and callbacks handed to the trainer.\n",
    "tb_logger = TensorBoardLogger(LOGS_PATH, name='default', version='version_2_no_augment')\n",
    "trainer_callbacks = [\n",
    "    TQDMProgressBar(refresh_rate=50),\n",
    "    LearningRateMonitor(logging_interval='step'),\n",
    "    # Keep the two best checkpoints by macro F1 (higher is better) ...\n",
    "    ModelCheckpoint(\n",
    "        filename='val_f1_macro_{epoch}_{val_f1_macro:.3f}',\n",
    "        monitor='val_f1_macro',\n",
    "        mode='max',\n",
    "        every_n_epochs=1,\n",
    "        save_top_k=2,\n",
    "    ),\n",
    "    # ... and the two best by validation loss (lower is better).\n",
    "    ModelCheckpoint(\n",
    "        filename='val_loss_{epoch}_{val_loss:.3f}',\n",
    "        monitor='val_loss',\n",
    "        mode='min',\n",
    "        every_n_epochs=1,\n",
    "        save_top_k=2,\n",
    "    ),\n",
    "]\n",
    "trainer_kwargs = {\n",
    "    'max_epochs': 50,\n",
    "    'gradient_clip_val': 1.0,\n",
    "    'gradient_clip_algorithm': 'norm',\n",
    "    'default_root_dir': CHECKPOINT_PATH,\n",
    "    'accelerator': 'gpu',\n",
    "    'devices': 1,\n",
    "    'num_sanity_val_steps': 0,\n",
    "    'check_val_every_n_epoch': 2,\n",
    "    'logger': [tb_logger],\n",
    "    'log_every_n_steps': 100,\n",
    "    'detect_anomaly': False,\n",
    "    'enable_progress_bar': True,\n",
    "    'enable_model_summary': False,\n",
    "    'enable_checkpointing': True,\n",
    "    'callbacks': trainer_callbacks,\n",
    "}\n",
    "estim.init_trainer(trainer_kwargs=trainer_kwargs)\n",
    "\n",
    "# TabNet hyperparameters together with optimizer and LR-schedule settings.\n",
    "lr_scheduler_kwargs = {\n",
    "    'step_size': 2,\n",
    "    'gamma': 0.9,\n",
    "    'verbose': True,\n",
    "}\n",
    "model_kwargs = {\n",
    "    'learning_rate': 0.005,\n",
    "    'weight_decay': 0.05,\n",
    "    'lr_scheduler': torch.optim.lr_scheduler.StepLR,\n",
    "    'lr_scheduler_kwargs': lr_scheduler_kwargs,\n",
    "    'optimizer': torch.optim.AdamW,\n",
    "    'lambda_sparse': 1e-5,\n",
    "    'n_d': 128,\n",
    "    'n_a': 64,\n",
    "    'n_steps': 1,\n",
    "    'gamma': 1.3,\n",
    "    'n_independent': 7,\n",
    "    'n_shared': 3,\n",
    "    'virtual_batch_size': 256,\n",
    "    'mask_type': 'entmax',\n",
    "    'augment_training_data': False,\n",
    "}\n",
    "estim.init_model(model_type='tabnet', model_kwargs=model_kwargs)\n",
    "\n",
    "# The trainer's built-in summary is disabled above, so print one explicitly.\n",
    "print(ModelSummary(estim.model))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "tags": []
   },
   "source": [
    "# Find learning rate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Learning-rate search over [1e-8, 10] with 100 training steps.\n",
    "# NOTE(review): kwargs are presumably forwarded to Lightning's LR finder -- confirm in cellnet.\n",
    "lr_find_res = estim.find_lr(\n",
    "    lr_find_kwargs={\n",
    "        'early_stop_threshold': 10.,\n",
    "        'min_lr': 1e-8,\n",
    "        'max_lr': 10.,\n",
    "        'num_training': 100,\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# lr_find_res[0] is the suggested learning rate; lr_find_res[1] holds the 'lr' and 'loss' series.\n",
    "ax = sns.lineplot(x=lr_find_res[1]['lr'], y=lr_find_res[1]['loss'])\n",
    "ax.set_xscale('log')\n",
    "# Hard-coded zoom window; adjust if the loss curve falls outside it.\n",
    "ax.set_ylim(5.25, top=7.)\n",
    "ax.set_xlim(1e-6, 10.)\n",
    "# Label the figure so it can be read on its own.\n",
    "ax.set_xlabel('learning rate')\n",
    "ax.set_ylabel('loss')\n",
    "ax.set_title('Learning rate finder')\n",
    "print(f'Suggested learning rate: {lr_find_res[0]}')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Fit model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    },
    "tags": []
   },
   "outputs": [],
   "source": [
    "# Start training via the estimator configured in the cells above.\n",
    "estim.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
